#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from .TTransport import *

import os
import socket
import sys
import warnings

if sys.version_info[0] >= 3:
    from io import BytesIO as StringIO
    from urllib import parse
    from http import client
    urlparse = parse
    urllib = parse
    httplib = client
else:
    from cStringIO import StringIO
    import urlparse
    import httplib
    import urllib

class THttpClient(TTransportBase):

    """Http implementation of TTransport base."""

    def __init__(self, uri_or_host, port=None, path=None):
        """THttpClient supports two different types constructor parameters.

        THttpClient(host, port, path) - deprecated
        THttpClient(uri)

        Only the second supports https."""

        if port is not None:
            warnings.warn(
                "Please use the THttpClient('http://host:port/path') syntax",
                DeprecationWarning,
                stacklevel=2)
            self.host = uri_or_host
            self.http_host = self.host
            self.port = port
            assert path
            self.path = path
            self.scheme = 'http'
        else:
            parsed = urlparse.urlparse(uri_or_host)
            self.scheme = parsed.scheme
            assert self.scheme in ('http', 'https')
            if self.scheme == 'http':
                self.port = parsed.port or httplib.HTTP_PORT
            elif self.scheme == 'https':
                self.port = parsed.port or httplib.HTTPS_PORT
            self.host = parsed.hostname
            self.http_host = parsed.netloc
            self.path = parsed.path
            if parsed.query:
                self.path += '?%s' % parsed.query
        self.__wbuf = StringIO()
        self.__http = None
        self.__timeout = None
        self.__custom_headers = None

    def open(self):
        if self.scheme == 'http':
            self.__http = httplib.HTTPConnection(self.host, self.port,
                                                 timeout=self.__timeout)
        else:
            self.__http = httplib.HTTPSConnection(self.host, self.port,
                                                  timeout=self.__timeout)

    def close(self):
        self.__http.close()
        self.__http = None

    def isOpen(self):
        return self.__http is not None

    def setTimeout(self, ms):
        if ms is None:
            self.__timeout = None
        else:
            self.__timeout = ms / 1000.0

    def setCustomHeaders(self, headers):
        self.__custom_headers = headers

    def setCustomHeader(self, name, value):
        if self.__custom_headers is None:
            self.__custom_headers = {}
        self.__custom_headers[name] = value

    def read(self, sz):
        return self.response.read(sz)

    def write(self, buf):
        self.__wbuf.write(buf)

    def flush(self):
        if self.isOpen():
            self.close()
        self.open()

        # Pull data out of buffer
        data = self.__wbuf.getvalue()
        self.__wbuf = StringIO()

        # HTTP request
        self.__http.putrequest('POST', self.path, skip_host=True)

        if not self.__custom_headers or 'Host' not in self.__custom_headers:
            self.__http.putheader('Host', self.http_host)

        self.__http.putheader('Content-Type', 'application/x-thrift')
        self.__http.putheader('Content-Length', str(len(data)))

        if not self.__custom_headers or 'User-Agent' not in \
                self.__custom_headers:
            user_agent = 'Python/THttpClient'
            script = os.path.basename(sys.argv[0])
            if script:
                user_agent = '%s (%s)' % (user_agent, urllib.quote(script))
            self.__http.putheader('User-Agent', user_agent)

        if self.__custom_headers:
            if sys.version_info[0] >= 3:
                custom_headers_iter = self.__custom_headers.items()
            else:
                custom_headers_iter = self.__custom_headers.items()
            for key, val in custom_headers_iter:
                self.__http.putheader(key, val)

        try:
            self.__http.endheaders()

            # Write payload
            self.__http.send(data)
        except socket.gaierror as e:
            raise TTransportException(TTransportException.NOT_OPEN, str(e))
        except Exception as e:
            raise TTransportException(TTransportException.UNKNOWN, str(e))


        # Get reply to flush the request
        self.response = self.__http.getresponse()
        self.code = self.response.status
        self.headers = self.response.getheaders()
